/**
 * Licensed to Big Data Genomics (BDG) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The BDG licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.bdgenomics.adam.ds.read.recalibration

import org.bdgenomics.adam.util.PhredUtils
import scala.math.{ exp, log }

/**
 * The final lookup table that maps error covariates to recalibrated Phred
 * scores.
 *
 * The companion object builds this table in stages:
 *  1. Error covariates are collected in an observation table.
 *  2. That table is turned into a TempRecalibrationTable.
 *  3. Every observed key is queried against the temporary table.
 *
 * This staging is needed because the final quality scores are computed
 * hierarchically.
 *
 * @param table Maps (default-normalized) covariate keys to recalibrated Phred
 *   scores, encoded as ASCII characters with the Illumina (33) offset.
 */
private[adam] case class RecalibrationTable private[recalibration] (
    private val table: Map[CovariateKey, Char]) {

  /**
   * Recalibrates the quality scores of every base in a read.
   *
   * A base whose key (after `toDefault`) is absent from the table keeps its
   * original quality score.
   *
   * @param covariates The covariates for all of the bases in a read, in read
   *   order.
   * @return Returns one recalibrated, Phred-scaled quality score per base, in
   *   the same order as the covariates. These scores are "unfiltered": the read
   *   may still choose to discard the recalibrated score for a base that was a
   *   low quality base.
   */
  def apply(covariates: Array[CovariateKey]): Array[Char] = {
    val recalibrated = new Array[Char](covariates.length)
    // Index-based loop (rather than covariates.map) avoids boxing each Char
    // on what is a per-base hot path.
    for (i <- covariates.indices) {
      val covariate = covariates(i)
      recalibrated(i) = table.getOrElse(covariate.toDefault, covariate.qualityScore)
    }
    recalibrated
  }
}

/**
 * A temporary table used to generate the BQSR recalibration table.
 *
 * In the prior implementation of BQSR, this class was the recalibration table.
 * The downside of that implementation was that the covariate &rarr; quality
 * mapping required a recursive calculation inside an inner loop that ran many
 * times. We expect the number of bases to be orders and orders of magnitude
 * larger than the number of covariate bins. So the calculation was refactored
 * into two steps:
 *  1. Generate this temporary table.
 *  2. Query it once with every known covariate key.
 *
 * The results of these queries go into a RecalibrationTable, which does a
 * constant-time lookup in a map. A further virtue of this approach is that
 * none of the code used to compute the recalibration scores needed to be
 * thrown out.
 *
 * @param tempTable A table mapping read group IDs to that read group's error
 *   aggregate and its per-quality-score table.
 * @param maxQualScore The maximum quality score to recalibrate to.
 */
private case class TempRecalibrationTable(
    tempTable: Map[Int, (Aggregate, QualityTable)],
    maxQualScore: Int = 50) {

  // Log error probability of the highest allowed quality score. This is the
  // smallest log probability that qualityFromLogP will emit.
  private val maxLogP = log(PhredUtils.phredToErrorProbability(maxQualScore))

  /**
   * Looks up the recalibrated Phred score of an error covariate by walking the
   * table hierarchy: read group, then quality score, then extra covariates.
   *
   * Each level's delta is measured against the estimate produced by the
   * levels above it.
   *
   * @param key The error covariate to look up. Its read group and quality
   *   score must be present in the table.
   * @return Returns an ASCII/Illumina (33) encoded quality score for this error
   *   covariate.
   */
  def lookup(key: CovariateKey): Char = {
    val globalEntry = tempTable(key.readGroupId)
    val globalDelta = computeGlobalDelta(globalEntry._1)
    // Log error probability implied by the base's reported (ASCII-33) quality.
    val residueLogP = log(PhredUtils.phredToErrorProbability(key.qualityScore.toInt - 33))
    val qualityEntry = getQualityEntry(key.qualityScore, globalEntry)

    // Running estimates; additions are performed in the same left-to-right
    // order as ((residue + global) + quality) + extras.
    val globalLogP = residueLogP + globalDelta
    val qualityDelta = computeQualityDelta(qualityEntry, globalLogP)
    val qualityLogP = globalLogP + qualityDelta
    val extrasDelta = computeExtrasDelta(qualityEntry, key, qualityLogP)
    qualityFromLogP(qualityLogP + extrasDelta)
  }

  /**
   * Converts a log probability to an encoded quality score.
   *
   * The probability is first clamped to the range [maxLogP, 0].
   *
   * @param logP A log probability.
   * @return This probability as a Phred scaled, ASCII/Illumina 33 encoded char.
   */
  private def qualityFromLogP(logP: Double): Char = {
    val boundedLogP = math.min(0.0, math.max(maxLogP, logP))
    (PhredUtils.errorProbabilityToPhred(exp(boundedLogP)) + 33).toChar
  }

  /**
   * @param globalEntry The aggregated global empirical error estimate for a
   *   single read group.
   * @return Returns the log scaled delta between the empirical error rate and
   *   the error rate predicted from the quality scores of all the observed
   *   bases.
   */
  private def computeGlobalDelta(
    globalEntry: Aggregate): Double = {
    log(globalEntry.bayesianErrorProbability()) - log(globalEntry.reportedErrorProbability)
  }

  /**
   * @param quality The ASCII encoded quality score to look up in the global
   *   table for a read group.
   * @param globalEntry The entry for this read group in the global table.
   * @return Returns the error aggregate and extra covariates table
   *   corresponding to a single base with a given quality score in a given
   *   read group. Uses Map.apply, so it fails if the quality score is absent.
   */
  private def getQualityEntry(
    quality: Char,
    globalEntry: (Aggregate, QualityTable)): (Aggregate, ExtrasTable) = {
    globalEntry._2.table(quality)
  }

  /**
   * @param qualityEntry The error aggregate/extra covariate pair corresponding
   *   to a single base with a given quality score in a single read group.
   * @param offset The log error probability of this quality bucket, given the
   *   corrected empirical error rate of all bases in this read group.
   * @return Returns the log scaled delta between the predicted and measured
   *   error rate for this bucket.
   */
  private def computeQualityDelta(qualityEntry: (Aggregate, ExtrasTable),
                                  offset: Double): Double = {
    log(qualityEntry._1.bayesianErrorProbability()) - offset
  }

  /**
   * @param qualityEntry The error aggregate/extra covariate pair corresponding
   *   to a single base with a given quality score in a single read group.
   * @param key The covariate key describing the error covariates that this
   *   read base maps into.
   * @param offset The log error probability of this base, given the corrected
   *   read group and quality score deltas.
   * @return Returns the sum of the log scaled deltas between the measured error
   *   rate of each extra covariate (cycle, then dinucleotide) and the offset.
   */
  private def computeExtrasDelta(qualityEntry: (Aggregate, ExtrasTable),
                                 key: CovariateKey,
                                 offset: Double): Double = {
    def tableContribution(aggregate: Aggregate): Double = {
      log(aggregate.bayesianErrorProbability()) - offset
    }

    // Each extra covariate contributes independently, relative to the same offset.
    val extrasTables = qualityEntry._2
    (tableContribution(extrasTables.cycleTable(key.cycle)) +
      tableContribution(extrasTables.dinucTable(key.dinuc)))
  }
}

private[recalibration] object RecalibrationTable {

  /**
   * Generates a recalibration table by mapping an observation table
   * into the hierarchical structure used by BQSR and then inverting said table.
   *
   * @param observed The observed covariates along with their empirical error
   *   rate estimates.
   * @return Returns a fully inverted recalibration table.
   */
  def apply(observed: ObservationTable): RecalibrationTable = {
    // Top level of the hierarchy: read group ID -> (read group aggregate,
    // per-quality-score table for that read group).
    val globalTable: Map[Int, (Aggregate, QualityTable)] = observed.entries
      .groupBy(_._1.readGroupId)
      .map(globalEntry => {
        (globalEntry._1,
          (aggregateObservations(globalEntry._2),
            computeQualityTable(globalEntry)))
      })

    // make a temp table to query
    val tt = new TempRecalibrationTable(globalTable)

    // Take all the covariates from the observation table, and query them
    // against the temporary table. Results are keyed by each key's toDefault
    // form, which is the form RecalibrationTable.apply looks up.
    val recalibrationQualityMappings = {
      observed.entries
        .keys
        .map(key => (key.toDefault, tt.lookup(key)))
        .toMap
    }

    RecalibrationTable(recalibrationQualityMappings)
  }

  /**
   * Builds the per-quality-score level of the hierarchy for one read group.
   *
   * @param globalEntry A read group ID paired with all observations from that
   *   read group.
   * @return Returns a table mapping each quality score seen in the read group
   *   to its error aggregate and its extra covariates table.
   */
  private def computeQualityTable(
    globalEntry: (Int, scala.collection.Map[CovariateKey, Observation])): QualityTable = {
    QualityTable(globalEntry._2.groupBy(_._1.qualityScore).map(qualityEntry => {
      val extras = computeExtrasTable(qualityEntry._2)
      (qualityEntry._1, (aggregateObservations(qualityEntry._2), extras))
    }))
  }

  /**
   * Builds the extra covariate level of the hierarchy for one quality bucket.
   *
   * @param table All observations in a single quality bucket of a single read
   *   group.
   * @return Returns the error aggregates grouped by sequencer cycle and by
   *   dinucleotide.
   */
  private def computeExtrasTable(
    table: scala.collection.Map[CovariateKey, Observation]): ExtrasTable = {
    // Groups the observations by one covariate and aggregates each group.
    // groupBy followed by map already yields a strict immutable Map, so no
    // further copy is needed.
    def makeTable[T](fn: CovariateKey => T): scala.collection.Map[T, Aggregate] = {
      table.groupBy(kv => fn(kv._1)).map(extraEntry => {
        (extraEntry._1, aggregateObservations(extraEntry._2))
      })
    }

    ExtrasTable(makeTable[Int]((ck: CovariateKey) => ck.cycle),
      makeTable[(Char, Char)]((ck: CovariateKey) => ck.dinuc))
  }

  /**
   * Aggregates observations over a table.
   *
   * Assumes that it is called on a non-empty table, as it is called on the
   * output of a groupBy.
   *
   * @param observations The non-empty map of observations to aggregate.
   * @return Returns the aggregated base/mismatch observations over a group of
   *   observed bases.
   */
  private def aggregateObservations(
    observations: scala.collection.Map[CovariateKey, Observation]): Aggregate = {
    assert(observations.nonEmpty)
    // Mapping a Map to a non-pair type yields an Iterable, so duplicate
    // aggregates are kept and all of them are summed.
    observations.map(p => {
      val (oldKey, obs) = p
      Aggregate(oldKey, obs)
    }).reduce(_ + _)
  }
}

/**
 * Per-read-group table of every quality score observed in that read group.
 *
 * @param table Maps each ASCII encoded quality score to the error aggregate
 *   for that score, paired with the table of extra covariates seen at it.
 */
private case class QualityTable(
    table: scala.collection.Map[Char, (Aggregate, ExtrasTable)])

/**
 * Extra covariates observed within one quality bucket of one read group.
 *
 * @param cycleTable Error aggregates keyed by sequencer cycle.
 * @param dinucTable Error aggregates keyed by a pair of nucleotides
 *   (dinucleotide context).
 */
private case class ExtrasTable(
    cycleTable: scala.collection.Map[Int, Aggregate],
    dinucTable: scala.collection.Map[(Char, Char), Aggregate])
